# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting AST to code and Python entities.

Adapted from Tangent.
"""

import atexit
import errno
import importlib
import importlib.util
import os
import sys
import tempfile

from nvidia.dali._autograph.pyct import origin_info
from nvidia.dali._autograph.pyct import parser


def _remove_file(file_name):
    """Remove a file, if it exists."""
    try:
        os.remove(file_name)
    except OSError as e:
        if e.errno == errno.ENOENT:
            # The file disappeared. Ignore this. Temporary files might get
            # cleaned up, especially if they reside in /tmp.
            pass
        else:
            raise


def load_source(source, delete_on_exit):
    """Loads the given source code as a Python module.

    The source is written to a temporary `.py` file (prefixed with
    `__autograph_generated_file`) that is not deleted when closed, and the
    module is then imported from that file.

    Args:
      source: Text, the Python source code to load.
      delete_on_exit: bool, whether to remove the temporary file when the
        interpreter exits.

    Returns:
      Tuple[module, Text], the loaded module and the path of the file it was
      loaded from.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", prefix="__autograph_generated_file", delete=False, encoding="utf-8"
    ) as f:
        file_name = f.name
        # The module is named after the file, without its ".py" suffix.
        module_name = os.path.splitext(os.path.basename(file_name))[0]
        f.write(source)

    if delete_on_exit:
        atexit.register(_remove_file, file_name)

    # `importlib.util` must be imported explicitly; `import importlib` alone does
    # not guarantee the submodule is available as an attribute.
    spec = importlib.util.spec_from_file_location(module_name, file_name)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # TODO(mdan): Use our own garbage-collected cache instead of sys.modules.
    # Registered only after execution succeeds, so a module whose code raised
    # is not left behind in sys.modules.
    sys.modules[module_name] = module
    return module, file_name


def load_ast(nodes, indentation="  ", include_source_map=False, delete_on_exit=True):
    """Loads the given AST as a Python module.

    The AST is unparsed to source code, which is then loaded through a temporary
    file. Compiling the AST code this way ensures that the source code is
    readable by e.g. `pdb` or `inspect`.

    Args:
      nodes: a single AST node, or a list or tuple of AST nodes, to compile.
        Any value that is not a list or tuple (including other iterables) is
        treated as a single node.
      indentation: Text, the string to use for indentation.
      include_source_map: bool, whether to return a source map.
      delete_on_exit: bool, whether to delete the temporary file used for
        compilation on exit.

    Returns:
      Tuple[module, Text, Dict[LineLocation, OriginInfo]], containing:
      the module containing the unparsed nodes, the source code corresponding to
      nodes, and the source map. If include_source_map is False, the source map
      will be None.
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes,)

    source = parser.unparse(nodes, indentation=indentation)
    module, file_name = load_source(source, delete_on_exit)

    if include_source_map:
        # Built against the file the module was actually loaded from.
        source_map = origin_info.create_source_map(nodes, source, file_name)
    else:
        source_map = None

    # TODO(mdan): Return a structured object.
    return module, source, source_map
